import argparse
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import os
import numpy as np
import time
import random
import torch.backends.cudnn
import sys
sys.path.append('../..')
from iclr23code import utils, dataset, backprop, modules, surrogate, LIF


def set_seed(seed):
    """Seed every random generator the script uses and pin cuDNN behaviour.

    Python's ``random``, NumPy's global generator and PyTorch's CPU/CUDA
    generators all receive the same seed. cuDNN autotuning is then disabled
    and deterministic cuDNN algorithms are requested.

    Args:
        seed: Integer seed passed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


# Run on a fixed GPU when CUDA is present, otherwise fall back to the CPU.
cuda_num = 'cuda:2'
if torch.cuda.is_available():
    device = torch.device(cuda_num)
else:
    device = torch.device('cpu')

# Wall-clock reference for the total-runtime report printed at the very end.
start_time = time.time()
# Fixed seed so repeated runs draw identical random streams.
seed = 1028
set_seed(seed)

# Command-line configuration. Every option below (except --data_path and
# --batch_size) also appears in the result-file prefix, so a rerun with
# identical arguments picks up the previous run's checkpoint and histories.
parser = argparse.ArgumentParser()
# --- data ---
parser.add_argument('--data_path', default='../../data', type=str, help='Path of data')
parser.add_argument('--dataset', default='CIFAR100', type=str, help='Dataset name',
                    choices=['CIFAR10', 'CIFAR100', 'IMAGE', 'AUGDVS'])
parser.add_argument('--batch_size', default=256, type=int, help='Batch size')

# --- neuron / network building blocks ---
parser.add_argument('--neuron', default='lif', type=str, help='Type of neuron',
                    choices=['lif'])
parser.add_argument('--num_step', default=5, type=int, help='Time dimension')
parser.add_argument('--bn', default='tdbn', type=str, help='Type of bn layer in model',
                    choices=['nn', 'straight', 'tdbn'])
parser.add_argument('--spike_func', default='triangle', type=str, help='Surrogate function',
                    choices=['rectangular', 'triangle', 'actan'])
parser.add_argument('--slope', default=1.0, type=float, help='Parameter used in surrogate')
parser.add_argument('--threshold', default=1.0, type=float, help='Membrane threshold')
parser.add_argument('--weak_mem', default=0.5, type=float, help='Weaken rate of mem')
parser.add_argument('--reset_mechanism', default='zero', type=str,
                    help='Membrane reset mechanism', choices=['subtract', 'zero'])

# --- architecture ---
parser.add_argument('--model', default='vgg16', type=str, help='Architecture of model',
                    choices=['vgg16', 'tdresnet19c10', 'sewresnet34'])

# --- training procedure / slope adjustment ---
parser.add_argument('--back_method', default='train', type=str, help='Backward method',
                    choices=['train', 'trainImg'])
parser.add_argument('--pre_process', default='logSoft', type=str, help='Pre process of result',
                    choices=['logSoft', 'soft', 'none'])
parser.add_argument('--adjust', default='bna', type=str, help='adjust class for slope',
                    choices=['d', 'bna'])
parser.add_argument('--key', default='0', type=str, help='idx of inner_list')
parser.add_argument('--momentum', default=0.0, type=float, help='Momentum in adjust class')
parser.add_argument('--limit_k', default=0.2, type=float, help='limit_k in adjust class')

# --- input coding / lr schedule ---
parser.add_argument('--code_method', default='copy', type=str, help='Coding method',
                    choices=['copy', 'dvs'])
parser.add_argument('--updater', default='warm', type=str, help='For lr update',
                    choices=['warm', 'img'])

# --- optimisation ---
parser.add_argument('--lr', default=1e-1, type=float, help='Max learning rate')
parser.add_argument('--num_epoch', default=300, type=int, help='Number of training epochs')
parser.add_argument('--best_acc', default=0, type=float,
                    help='Best test acc of current model; only a higher test acc '
                         'overwrites the saved checkpoint.')
parser.add_argument('--optimizer', default='sgd', type=str,
                    help='If use adam, set a lower lr(0.01)', choices=['adam', 'sgd'])
parser.add_argument('--appendix', default='0', type=str, help='Append information')
args = parser.parse_args()

# Datasets and loaders. Training shuffles and drops the trailing partial batch,
# so every training batch holds exactly `batch_size` samples. Testing keeps
# the original order and every sample.
train_data, test_data = dataset.dataload(args.data_path, args.dataset)
_shared_loader_kwargs = dict(batch_size=args.batch_size, pin_memory=True)
train_loader = DataLoader(train_data, shuffle=True, drop_last=True, **_shared_loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **_shared_loader_kwargs)

# Neuron model: LIF is the only implementation on offer.
if args.neuron != 'lif':
    raise ValueError("Can't find choice in args.neuron.")
neuron = LIF

# Batch-norm layer type, looked up by name so only the selected attribute of
# `modules` is ever accessed.
_bn_attr = {'nn': 'batch_norm2d', 'straight': 'straight', 'tdbn': 'td_bn'}.get(args.bn)
if _bn_attr is None:
    raise ValueError("Can't find choice in args.bn.")
bn = getattr(modules, _bn_attr)

# Surrogate function: each choice names the attribute of the same name in `surrogate`.
if args.spike_func not in ('rectangular', 'triangle', 'actan'):
    raise ValueError("Can't find choice in args.spike_func.")
spike_func = getattr(surrogate, args.spike_func)

# Architecture selection. `recorder_size` is the column count of the per-epoch
# fire/slope/mean/var recorders allocated further down. It is presumably one
# column per recorded spiking layer of the chosen model; confirm against
# backprop.store_info.
if args.model == 'vgg16':
    model = modules.NBAvgVgg16
    recorder_size = 15
elif args.model == 'tdresnet19c10':
    model = modules.TdResNet19C10
    recorder_size = 18
elif args.model == 'sewresnet34':
    print("Note: If use model `sewresnet34`, remember to use IF neuron(set weak_mem to 1.0).")
    model = modules.SewResNet34
    recorder_size = 36
else:
    # Message now reads "find" (was "fine"), matching every other choice check.
    raise ValueError("Can't find choice in args.model.")

# Every model class is constructed with the same positional arguments.
net = model(neuron, args.num_step, bn, spike_func, args.slope, args.threshold, args.weak_mem,
            args.reset_mechanism, args.dataset).to(device)

# Per-batch training routine, called once per batch in the training loop.
if args.back_method == 'train':
    back_method = backprop.train
elif args.back_method == 'trainImg':
    back_method = backprop.trainImg
else:
    raise ValueError("Can't find choice in args.back_method.")

# Optional transform passed to back_method together with the criterion.
# The accepted CLI value is lowercase 'none' (see the --pre_process choices).
# The previous comparison against 'None' could never match, so
# `--pre_process none` always raised ValueError.
if args.pre_process == 'none':
    pre_process = None
elif args.pre_process == 'logSoft':
    pre_process = nn.LogSoftmax(dim=-1)
elif args.pre_process == 'soft':
    pre_process = nn.Softmax(dim=-1)
else:
    raise ValueError("Can't find choice in args.pre_process.")

# `limit` is passed to the adjust object and is also embedded in the save-name prefix.
limit = 0

# Learning-rate scheduler; both variants take (num_epoch, max lr).
_updater_classes = {'warm': 'WarmUpdate', 'img': 'ImgUpdate'}
if args.updater not in _updater_classes:
    raise ValueError("Can't find choice in args.updater.")
updater = getattr(utils, _updater_classes[args.updater])(args.num_epoch, args.lr)

# Slope-adjust object; the 'bna' variant additionally takes key, momentum and limit_k.
if args.adjust == 'bna':
    adjust = backprop.BNAdjust(limit, args.key, args.momentum, args.limit_k)
elif args.adjust == 'd':
    adjust = backprop.Adjust(limit)
else:
    raise ValueError("Can't find choice in args.adjust.")

# Input coding function, called as code_method(data, num_step) on every batch.
_code_methods = {'copy': 'copy_code', 'dvs': 'dvs_code'}
if args.code_method not in _code_methods:
    raise ValueError("Can't find choice in args.code_method.")
code_method = getattr(utils, _code_methods[args.code_method])

# Every artifact of a run (checkpoint and metric histories) shares one prefix
# that encodes the configuration. A rerun with identical arguments therefore
# reloads the saved weights and extends the saved histories.
# NOTE: there is deliberately no '_' between reset_mechanism and `limit`, or
# between optimizer and `lr`. Changing the format would orphan files written
# by earlier runs.
work_dir = './result/'
os.makedirs(work_dir, exist_ok=True)
_name_parts = [
    args.dataset, args.code_method, str(args.num_step), args.back_method, args.updater,
    args.pre_process, args.adjust, args.key, str(args.momentum), str(args.limit_k),
    args.model, args.neuron, args.bn, args.spike_func, str(args.slope),
    str(args.threshold), str(args.weak_mem), args.reset_mechanism + str(limit),
    args.optimizer + str(args.lr), args.appendix,
]
prefix_save_name = work_dir + '_'.join(_name_parts)
# ==============================================================================================
print(prefix_save_name)
model_save_name = prefix_save_name + '_state_dict.pth'
if not os.path.exists(model_save_name):
    print("Training start from scratch.")
else:
    print('Use trained model.')
    net.load_state_dict(torch.load(model_save_name, map_location=device))
    modules.reset_model_state()

# Metric history files. Histories from earlier runs are loaded so this run
# appends to them instead of overwriting them.
train_loss_save_name = prefix_save_name + '_train_loss.npy'
train_acc_save_name = prefix_save_name + '_train_acc.npy'
test_acc_save_name = prefix_save_name + '_test_acc.npy'
train_fire_save_name = prefix_save_name + '_train_fire.npy'
train_slope_save_name = prefix_save_name + '_train_slope.npy'
train_mean_save_name = prefix_save_name + '_train_mean.npy'
train_var_save_name = prefix_save_name + '_train_var.npy'


def _load_saved(path):
    """Return the array stored at `path`, or None when no such file exists."""
    if os.path.exists(path):
        return np.load(path)
    return None


def _load_saved_list(path):
    """Return the history saved at `path` as a Python list ([] when absent)."""
    saved = _load_saved(path)
    return [] if saved is None else saved.tolist()


train_loss_recorder = _load_saved_list(train_loss_save_name)
train_acc_recorder = _load_saved_list(train_acc_save_name)
test_acc_recorder = _load_saved_list(test_acc_save_name)

# Buffers for this run's per-epoch statistics, one row per epoch. np.empty
# leaves them uninitialised; only the first `actual_epoch` rows of the combined
# history are written when saving.
_recorder_shape = (args.num_epoch, recorder_size)
fire_recorder = np.empty(_recorder_shape)
slope_recorder = np.empty(_recorder_shape)
mean_recorder = np.empty(_recorder_shape)
var_recorder = np.empty(_recorder_shape)

# Statistics from earlier runs (None when starting fresh).
old_fire_recorder = _load_saved(train_fire_save_name)
old_slope_recorder = _load_saved(train_slope_save_name)
old_mean_recorder = _load_saved(train_mean_save_name)
old_var_recorder = _load_saved(train_var_save_name)

# The epoch counter continues from the number of rows already in the saved fire history.
actual_epoch = 0 if old_fire_recorder is None else old_fire_recorder.shape[0]

# ======================================= Setting =========================================
# Checkpoint threshold: a test accuracy must exceed this to overwrite the saved model.
best_acc = args.best_acc
best_epoch = 0

len_train_loader = len(train_loader)
# The train loader uses drop_last=True, so every batch is full and this is the
# exact number of training samples seen per epoch.
len_train_data = args.batch_size * len_train_loader
len_test_data = len(test_data)
for _label, _count in (('len_train_loader:', len_train_loader),
                       ('len_train_data:', len_train_data),
                       ('len_test_data:', len_test_data)):
    print(_label, _count)

criterion = nn.NLLLoss()
_optimizer_classes = {'sgd': torch.optim.SGD, 'adam': torch.optim.Adam}
if args.optimizer not in _optimizer_classes:
    raise ValueError("Can't find choices in args.optimizer.")
optimizer = _optimizer_classes[args.optimizer](net.parameters(), lr=args.lr)

# ========================================= Train ==========================================
# Reset the slope-adjust object before the first epoch; the same call is
# repeated at the end of every epoch below.
adjust.reset()

for epoch in range(args.num_epoch):
    # Apply this epoch's learning rate through the scheduler (presumably by
    # writing into optimizer.param_groups; TODO confirm in utils), then print
    # the optimizer configuration.
    updater.lr_update(epoch, optimizer)
    print(optimizer)

    net.train()
    total_train_loss = 0
    total_train_cor = 0

    for idx, (data, target) in enumerate(train_loader):
        # Only on the last batch of the epoch is the adjust object started
        # before the training step. It is closed again after the statistics
        # below have been stored.
        if idx == len_train_loader - 1:
            adjust.start()
        # Encode the batch over `num_step` time steps, then fold it. Folding is
        # assumed to merge the time dimension into the batch dimension; confirm
        # in utils.fold.
        data = utils.fold(code_method(data, args.num_step)).to(device)
        target = target.to(device)
        # Training step (back_method receives the optimizer, so presumably it
        # also steps it). Returns (batch loss, correct count); the count is
        # divided by the sample total for accuracy below.
        train_loss, train_cor = back_method(net, args.num_step, data, target, optimizer,
                                            criterion, pre_process)
        total_train_loss += train_loss
        total_train_cor += train_cor
        print(f"\rEpoch: {epoch + 1}/{args.num_epoch}, Batch: {idx + 1}/{len_train_loader}, "
              f"Train loss: {train_loss}", end='')
        if idx == len_train_loader - 1:
            print()
            # Store this epoch's statistics into the fire/slope/mean/var
            # recorders (presumably row `epoch`; see backprop.store_info).
            backprop.store_info(args.batch_size, args.num_step, epoch, fire_recorder,
                                slope_recorder, mean_recorder, var_recorder)
            # Counts epochs with stored statistics; used to trim the recorders when saving.
            actual_epoch += 1
            adjust.close()
    # The per-epoch loss is the sum (not the mean) of the batch losses.
    train_loss_recorder.append(total_train_loss)
    # len_train_data is exact because the train loader drops the last partial batch.
    train_acc = 100 * total_train_cor / len_train_data
    train_acc_recorder.append(train_acc)

    net.eval()
    with torch.no_grad():
        test_cor = 0
        for (data, target) in test_loader:
            data = utils.fold(code_method(data, args.num_step)).to(device)
            target = target.to(device)
            test_cor += backprop.test(net, args.num_step, data, target)
    test_acc = 100 * test_cor / len_test_data
    test_acc_recorder.append(test_acc)

    # Save a checkpoint only when the test accuracy beats the best seen so far.
    if best_acc < test_acc:
        best_acc = test_acc
        best_epoch = epoch + 1
        torch.save(net.state_dict(), model_save_name)
    print(f"Current train acc: {train_acc}, Current test acc: {test_acc}, Best test acc: "
          f"{best_acc}")

    # End-of-epoch housekeeping: advance the adjust object, reset the model's
    # internal state (presumably neuron membrane state; confirm in modules),
    # then reset the adjust object for the next epoch.
    adjust.step(epoch)
    modules.reset_model_state()
    adjust.reset()
print(f"Best test acc: {best_acc}, Epoch: {best_epoch}")

# Persist the scalar per-epoch histories; they already include earlier runs' values.
np.save(train_loss_save_name, train_loss_recorder)
np.save(train_acc_save_name, train_acc_recorder)
np.save(test_acc_save_name, test_acc_recorder)

# Per-epoch statistics: prepend any history loaded at start-up, then keep only
# the first `actual_epoch` rows. This drops the unfilled np.empty rows.
for _save_name, _old, _new in (
        (train_fire_save_name, old_fire_recorder, fire_recorder),
        (train_slope_save_name, old_slope_recorder, slope_recorder),
        (train_mean_save_name, old_mean_recorder, mean_recorder),
        (train_var_save_name, old_var_recorder, var_recorder),
):
    _combined = _new if _old is None else np.concatenate((_old, _new), axis=0)
    np.save(_save_name, _combined[: actual_epoch])

end_time = time.time()
print(f"Total use: {(end_time - start_time) / 60} minutes.")
